import os

class Config:
    """Static settings container: dataset paths and training hyperparameters.

    Every value is a plain class attribute computed once at import time, so
    the settings are read directly as ``Config.<name>`` without instantiating
    the class.
    """

    # --- Base paths ---------------------------------------------------------
    # BASE_DIR is the directory two levels above the one containing this file
    # (three dirname() hops starting from this file's absolute path).
    BASE_DIR = os.path.dirname(
        os.path.dirname(
            os.path.dirname(
                os.path.abspath(__file__)
            )
        )
    )
    # Root folder of the dataset; every path below is derived from it.
    DATA_DIR = os.path.join(BASE_DIR, "NDataset")

    # --- Dataset locations --------------------------------------------------
    # The *_data entries have no file extension (presumably image folders --
    # confirm against the data loader); the *_label entries are JSON files and
    # submit_file is a CSV file.
    data_dir = {
        "train_data": os.path.join(DATA_DIR, "mchar_train"),
        "val_data": os.path.join(DATA_DIR, "mchar_val"),
        "test_data": os.path.join(DATA_DIR, "mchar_test_a"),
        "train_label": os.path.join(DATA_DIR, "mchar_train.json"),
        "val_label": os.path.join(DATA_DIR, "mchar_val.json"),
        "submit_file": os.path.join(DATA_DIR, "mchar_sample_submit_A.csv"),
    }

    # --- Training settings --------------------------------------------------
    batch_size = 64
    lr = 1e-3  # learning rate
    momentum = 0.9
    weights_decay = 1e-4
    # NOTE(review): presumably 10 digit classes plus one extra "blank" class
    # -- confirm against the model head.
    class_num = 11
    # NOTE(review): the units of the three intervals below are not visible
    # here (likely epochs for eval/checkpoint and iterations for print) --
    # confirm against the training loop.
    eval_interval = 1
    checkpoint_interval = 1
    print_interval = 50

    # --- Model checkpoint locations -----------------------------------------
    checkpoints = os.path.join(DATA_DIR, "checkpoints")
    pretrained = os.path.join(checkpoints, "epoch-resnet18-52-bn-acc-73.86.pth")

    # --- Number of training epochs ------------------------------------------
    start_epoch = 0
    epoches = 10

    # --- Label smoothing parameter ------------------------------------------
    smooth = 0.1
    # NOTE(review): presumably the probability of applying a random-erasing
    # augmentation -- confirm against the transform pipeline.
    erase_prob = 0.5